
from openai import OpenAI
import os
import json
from dotenv import load_dotenv
from tqdm import tqdm
import argparse

load_dotenv()


def llm_shortest_gen(text: list, api: str, verbose: bool = False,
                     model: str = "gpt-4o-mini-2024-07-18") -> str:
    """Ask an OpenAI chat model to pick the shortest of several texts.

    The candidates are numbered in a prompt that asks the model for the text
    with the fewest words (ties go to the earliest one), replying in the form
    ``shortest text: <text>``.

    Args:
        text: Candidate texts. Any non-zero number of items is accepted; each
            item is interpolated into the prompt via f-string formatting.
        api: OpenAI API key handed to the client.
        verbose: If True, print the selected text.
        model: Chat-completions model name (defaults to the model this script
            was originally written against).

    Returns:
        Everything after the first ``"shortest text: "`` marker in the model's
        reply, or the whole stripped reply if the marker is missing. A single
        candidate is returned unchanged without calling the API.

    Raises:
        ValueError: If ``text`` is empty.
    """
    if not text:
        raise ValueError("llm_shortest_gen() requires at least one text")

    if len(text) == 1:
        # A lone candidate is trivially the shortest; skip the paid API call.
        if verbose:
            print(text[0])
        return text[0]

    # Number every candidate instead of hard-coding five: the old prompt raised
    # IndexError for <5 texts and silently dropped texts past the fifth. The
    # 4-space layout reproduces the original prompt exactly for 5 texts.
    numbered_texts = "\n".join(
        f"    Text {i}: {t}" for i, t in enumerate(text, start=1)
    )
    shortest_llm_prompt = f"""
    You are given {len(text)} texts.
{numbered_texts}

    Your task is to output the shortest text amongst the given texts in terms of total words in a text. Compare the texts and select the one that is the shortest. If there are multiple texts with the same length, select the first one.

    Strictly follow the guidelines above.
    
    Return your generation in the following format. Do not include any other text:

    shortest text: [your shortest text here]

    """

    # The context manager closes the client's HTTP resources deterministically
    # instead of leaving that to garbage collection on every call.
    with OpenAI(api_key=api) as client:
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": shortest_llm_prompt},
            ],
        )

    # message.content may be None (e.g. a refusal); treat that as an empty reply.
    reply = (completion.choices[0].message.content or "").strip()
    # Keep everything after the first marker (split(...)[1] would cut a text
    # that itself contains the marker); fall back to the whole reply when the
    # model ignored the requested format.
    _, marker, after_marker = reply.partition("shortest text: ")
    shortest = after_marker.strip() if marker else reply

    if verbose:
        print(shortest)
    return shortest

def main() -> None:
    """Replace each topic's responses with the LLM-selected shortest one.

    For every ``*.json`` file in ``--input_dir``, loads the data (assumed to be
    a list of topic dicts each holding a ``"Responses"`` list -- inferred from
    usage, confirm against the data), sets ``topic["Responses"]`` to a
    one-element list containing the shortest response, and writes the result
    under the same filename in ``--output_dir``.
    """
    parser = argparse.ArgumentParser(description="Generate LLM-based shortest response selection from a list of responses.")
    parser.add_argument('--input_dir', type=str, required=True, help="Input dir path to the input JSON files.")
    parser.add_argument('--output_dir', type=str, required=True, help="Output dir path to the output JSON files.")
    args = parser.parse_args()

    api = os.getenv("OPENAI_API_KEY")
    if not api:
        # Fail fast instead of erroring on the first API request.
        parser.error("OPENAI_API_KEY is not set (environment or .env file)")

    # Create the output directory up front; otherwise the first write fails
    # only after that file's (paid) API calls have already been made.
    os.makedirs(args.output_dir, exist_ok=True)

    # Only regular .json files; sorted so runs process files in a stable order.
    filenames = sorted(
        name for name in os.listdir(args.input_dir)
        if name.lower().endswith(".json")
        and os.path.isfile(os.path.join(args.input_dir, name))
    )

    for filename in tqdm(filenames):
        file_path = os.path.join(args.input_dir, filename)
        with open(file_path, 'r', encoding="utf-8") as json_file:
            data = json.load(json_file)

        for topic in tqdm(data, leave=False):
            shortest = llm_shortest_gen(topic["Responses"], api, verbose=False)
            topic["Responses"] = [shortest]

        res_file_path = os.path.join(args.output_dir, filename)
        with open(res_file_path, 'w', encoding="utf-8") as json_file:
            json.dump(data, json_file, indent=4)

        print("Done ", filename)


if __name__ == "__main__":
    main()